{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8f50a762",
   "metadata": {},
   "source": [
    "## Using LMI Containers with SageMaker Async Endpoints\n",
    "\n",
    "This notebook will demonstrate usage of LMI DLCs to host models on SageMaker Async Inference Endpoints. Support for Async Inference with LMI requires using 0.31.0 container versions or later."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d8df2f",
   "metadata": {},
   "source": [
    "### Install and Update Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "103241eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U sagemaker boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8927fe6",
   "metadata": {},
   "source": [
    "### Create and deploy a Model for Async Inference\n",
    "\n",
    "You need to create an [AsyncInferenceConfig](https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html) in order to deploy an async endpoint. In this example, we will be using the default AsyncInferenceConfig, but you are welcome to customize it as needed.\n",
    "\n",
    "This example deploys the [Llama3.1-8b-Instruct](meta-llama/Llama-3.1-8B-Instruct) model. This is a gated model and requires a HuggingFace account that has been granted permissions to the model, and a valid hub access token. If you do not have access to this model, you can use another text generation model. In this example we will use the OpenAI Chat Completions request format, so you need to use a model with a chat template."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2320a52e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.djl_inference import DJLModel\n",
    "from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig\n",
    "from sagemaker.async_inference.waiter_config import WaiterConfig\n",
    "from sagemaker.session import Session\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124\"\n",
    "session = Session()\n",
    "\n",
    "model = DJLModel(\n",
    "    image_uri=image_uri,\n",
    "    env = {\n",
    "        \"HF_MODEL_ID\": \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "        \"HF_TOKEN\": \"<your hub token>\"\n",
    "    },\n",
    "    role=role,\n",
    "    sagemaker_session=session,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c92166",
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_name = sagemaker.utils.name_from_base(\"my-lmi-async-endpoint\")\n",
    "async_inference_config = AsyncInferenceConfig()\n",
    "\n",
    "model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.g6.12xlarge\",\n",
    "    endpoint_name=endpoint_name,\n",
    "    async_inference_config=async_inference_config,\n",
    "    container_startup_health_check_timeout=2400,\n",
    ")"
   ]
  },
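  {
   "cell_type": "markdown",
   "id": "b41e7c90",
   "metadata": {},
   "source": [
    "The deployment above relies on the default `AsyncInferenceConfig`. If you want to control where results are written or how many requests each instance processes at once, a customized configuration might look like the following sketch. The bucket name, prefix, and concurrency value are placeholders chosen for illustration, not recommendations. To use it, pass the resulting object to `model.deploy` in place of the default configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52f8da1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig\n",
    "\n",
    "# Hypothetical output location; replace with a bucket and prefix you own\n",
    "custom_output_path = \"s3://my-example-bucket/async_lmi_outputs\"\n",
    "\n",
    "custom_async_inference_config = AsyncInferenceConfig(\n",
    "    output_path=custom_output_path,  # where SageMaker writes inference results\n",
    "    max_concurrent_invocations_per_instance=4,  # illustrative value only\n",
    ")\n",
    "print(custom_async_inference_config.output_path)"
   ]
  },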
  {
   "cell_type": "markdown",
   "id": "fd5dfde6",
   "metadata": {},
   "source": [
    "### Create Sample inputs and upload to s3\n",
    "\n",
    "Async Endpoints are invoked with an s3 object that contains your inference request. We will create two sample inference requests and upload them to s3.\n",
    "\n",
    "Async inference is not compatible with streaming. You cannot specify `\"stream\": true` in the payload."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d985087a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile sample_inputs.json\n",
    "{\n",
    "    \"inputs\": \"Please give me a 10 day itinerary for my trip to New York. Sure, starting on day 1 \",\n",
    "    \"parameters\": {\n",
    "        \"temperature\": 0.6,\n",
    "        \"top_p\": 0.9,\n",
    "        \"max_new_tokens\": 1024\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25c91ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile sample_messages.json\n",
    "{\n",
    "    \"messages\": [\n",
    "        {\"role\": \"user\", \"content\": \"Please give me a 10 day itinerary for my trip to New York. Sure, starting on day 1\"}\n",
    "    ],\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "    \"max_tokens\": 1024\n",
    "}"
   ]
  },
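  {
   "cell_type": "markdown",
   "id": "d63a9eb2",
   "metadata": {},
   "source": [
    "Because streaming is not supported, it can help to check a payload before uploading it. The sketch below uses a hypothetical helper, `reject_streaming`, which is not part of SageMaker or LMI. It only inspects the top-level `stream` key of a request that has already been parsed into a dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74bafc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reject_streaming(payload: dict) -> dict:\n",
    "    \"\"\"Hypothetical helper: raise if the request asks for a streamed response.\"\"\"\n",
    "    if payload.get(\"stream\") is True:\n",
    "        raise ValueError('Async inference does not support streaming; remove \"stream\": true')\n",
    "    return payload\n",
    "\n",
    "\n",
    "chat_payload = {\n",
    "    \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}],\n",
    "    \"max_tokens\": 64,\n",
    "}\n",
    "\n",
    "# Passes: the payload has no \"stream\" key\n",
    "reject_streaming(chat_payload)\n",
    "\n",
    "# Fails: the same payload with streaming enabled\n",
    "try:\n",
    "    reject_streaming({**chat_payload, \"stream\": True})\n",
    "except ValueError as err:\n",
    "    print(f\"Rejected: {err}\")"
   ]
  },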
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da773200",
   "metadata": {},
   "outputs": [],
   "source": [
    "bucket = session.default_bucket()\n",
    "\n",
    "# Upload the request following the default LMI schema\n",
    "sample_input_path = session.upload_data(\n",
    "    \"sample_inputs.json\",\n",
    "    bucket=bucket,\n",
    "    key_prefix=\"async_lmi_inputs\"\n",
    ")\n",
    "# Upload the request following the OpenAI Chat Completions schema\n",
    "sample_messages_path = session.upload_data(\n",
    "    \"sample_messages.json\",\n",
    "    bucket=bucket,\n",
    "    key_prefix=\"async_lmi_inputs\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25b873d2",
   "metadata": {},
   "source": [
    "### Create the Async Predictor and make inference requests\n",
    "\n",
    "The [AsyncPredictor](https://sagemaker.readthedocs.io/en/stable/api/inference/predictor_async.html) provides utility methods for interacting with the async endpoint and making inference requests.\n",
    "\n",
    "In this example, we'll be using the predict_async method as it is non-blocking. We also specify `initial_args={\"ContentType\": \"application/json\"}` so that the request gets serialized correctly and can be hanlded by the container."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ea48dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = sagemaker.Predictor(endpoint_name=endpoint_name)\n",
    "async_predictor = sagemaker.predictor_async.AsyncPredictor(predictor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7ffaa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "async_response_sample_inputs = async_predictor.predict_async(\n",
    "    input_path=sample_input_path,\n",
    "    initial_args={\"ContentType\": \"application/json\"},\n",
    ")\n",
    "async_response_sample_messages = async_predictor.predict_async(\n",
    "    input_path=sample_messages_path,\n",
    "    initial_args={\"ContentType\": \"application/json\"},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82ddaa7d",
   "metadata": {},
   "source": [
    "### Poll for Inference Completion\n",
    "\n",
    "You can use the [WaiterConfig](https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html#sagemaker.async_inference.waiter_config.WaiterConfig) to configure the polling cadence for inference results. We will use the default WaiterConfig in this example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebffc95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "waiter_config = WaiterConfig()\n",
    "\n",
    "inputs_result = async_response_sample_inputs.get_result(waiter_config=waiter_config)\n",
    "messages_result = async_response_sample_messages.get_result(waiter_config=waiter_config)\n",
    "\n",
    "print(f\"Result from LMI style request is:\\n {json.loads(inputs_result)}\")\n",
    "print(\"--------------------------\")\n",
    "print(f\"Result from OpenAI style request is:\\n {json.loads(messages_result)}\")"
   ]
  },
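  {
   "cell_type": "markdown",
   "id": "f85cb0d4",
   "metadata": {},
   "source": [
    "Long generations may need a longer polling window than the default. As a rough sketch, a `WaiterConfig` takes a `max_attempts` count and a `delay` in seconds between attempts, so their product bounds how long `get_result` waits. The values below are illustrative assumptions, not tuned recommendations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a96dc1e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.async_inference.waiter_config import WaiterConfig\n",
    "\n",
    "# Illustrative values: poll every 10 seconds, up to 120 times\n",
    "patient_waiter_config = WaiterConfig(max_attempts=120, delay=10)\n",
    "\n",
    "# Upper bound on how long get_result would wait with this configuration\n",
    "max_wait_seconds = patient_waiter_config.max_attempts * patient_waiter_config.delay\n",
    "print(f\"get_result would wait up to about {max_wait_seconds} seconds\")"
   ]
  },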
  {
   "cell_type": "markdown",
   "id": "4836bb7b",
   "metadata": {},
   "source": [
    "### Clean up Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f954a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "async_predictor.delete_endpoint()\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
